# SPDX-License-Identifier: Apache-2.0
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
    --model-path /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save

The saved checkpoint can then be loaded with:

llm = Engine(
    model_path="/path/to/save",
    load_format="sharded_state",
    quantization="deepspeedfp",
    tensor_parallel_size=8,
)
"""
import dataclasses
import os
import shutil
from argparse import ArgumentParser
from pathlib import Path

from sglang import Engine, ServerArgs

parser = ArgumentParser()
# Start from the full set of ServerArgs options (model path, quantization,
# tensor parallel size, ...) and add the script-specific flags below.
ServerArgs.add_cli_args(parser)

parser.add_argument(
    "--output", "-o", required=True, type=str, help="path to output checkpoint"
)
parser.add_argument(
    "--file-pattern", type=str, help="string pattern of saved filenames"
)
# type=int so that a size given on the command line is parsed as a number,
# matching the integer default; with type=str a user-supplied value would be
# forwarded to save_sharded_model as a string.
parser.add_argument(
    "--max-file-size",
    type=int,
    default=5 * 1024**3,
    help="max size (in bytes) of each safetensors file",
)


def main(args):
    """Save a model as per-worker shards and copy its metadata alongside.

    Builds ``ServerArgs`` from the parsed command line, starts an ``Engine``,
    and calls ``save_sharded_model`` to write the sharded state into
    ``args.output``. It then copies every entry of the source model directory
    whose extension is not ``.bin``, ``.pt`` or ``.safetensors`` (config and
    tokenizer files, subdirectories, ...) into the output directory.

    Args:
        args: Namespace from ``parser.parse_args()``. It carries the
            ``ServerArgs`` CLI options plus ``output``, ``file_pattern`` and
            ``max_file_size``.

    Raises:
        ValueError: If the model path is not a local directory.
    """
    engine_args = ServerArgs.from_cli_args(args)
    model_path = Path(engine_args.model_path)
    if not model_path.is_dir():
        raise ValueError("model path must be a local directory")

    # Create LLM instance from arguments
    llm = Engine(**dataclasses.asdict(engine_args))

    output_dir = Path(args.output)
    # parents=True: allow an output path whose parent directories do not
    # exist yet (plain mkdir would raise FileNotFoundError there).
    output_dir.mkdir(parents=True, exist_ok=True)
    llm.save_sharded_model(
        path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
    )

    # Copy metadata files to output directory. Original weight files are
    # skipped; the sharded weights were written by save_sharded_model above.
    weight_suffixes = (".bin", ".pt", ".safetensors")
    resolved_output = output_dir.resolve()
    for entry in model_path.iterdir():
        if entry.suffix in weight_suffixes:
            continue
        # If the output directory lives inside the model directory, do not
        # copy it into itself.
        if entry.resolve() == resolved_output:
            continue
        destination = output_dir / entry.name
        if entry.is_dir():
            # dirs_exist_ok lets the script be re-run into an existing output
            # directory instead of failing with FileExistsError.
            shutil.copytree(entry, destination, dirs_exist_ok=True)
        else:
            shutil.copy(entry, destination)


if __name__ == "__main__":
    # Parse ServerArgs options plus the script-specific flags, then run.
    main(parser.parse_args())
